# -*- coding: utf-8 -*-
#
# plot_progress.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST.  If not, see <http://www.gnu.org/licenses/>.

r"""Generate a GIF of the network solving a Sudoku puzzle
----------------------------------------------------------------
This script takes one of the .pkl files generated by
``sudoku_solver.py`` and generates a GIF showing the progress
of the network solving the puzzle.

Note that the script generates the images individually, storing
them to disk first, assembling them into a GIF and then,
by default, deleting the images and folder.

See Also
~~~~~~~~

:doc:`sudoku_solver`

:doc:`helpers_sudoku`

:Authors: J Gille
"""
import os
import pickle
import sys
from glob import glob

import helpers_sudoku
import imageio
import matplotlib.pyplot as plt
import numpy as np


def get_progress(puzzle, solution):
    """Rate how close ``solution`` is to solving ``puzzle``.

    The check itself is delegated to ``helpers_sudoku.validate_solution``,
    which returns an overall validity flag plus three objects that are
    summed here (presumably per-box, per-row and per-column validity
    flags -- verify against ``helpers_sudoku``).

    Parameters
    ----------
    puzzle
        The puzzle, passed on unchanged to ``validate_solution``.
    solution
        The candidate solution, passed on unchanged to ``validate_solution``.

    Returns
    -------
    float
        ``1.0`` if the solution is reported as valid, otherwise the sum of
        the three check results divided by 27 (presumably 9 boxes + 9 rows
        + 9 columns).
    """
    is_valid, box_checks, row_checks, col_checks = helpers_sudoku.validate_solution(puzzle, solution)
    if not is_valid:
        satisfied = box_checks.sum() + row_checks.sum() + col_checks.sum()
        return satisfied / 27
    return 1.0


# Names of the .pkl files to read from.
in_files = ["350Hz_puzzle_4.pkl"]

# Name of the output GIF.
out_file = "sudoku.gif"

# Name of the directory that holds the temporary image files.
temp_dir = "tmp"

# Set to True to keep the temporary files instead of deleting them.
keep_temps = False


# Never overwrite an existing GIF. Exit with a non-zero status so that a
# calling shell or script can tell that nothing was generated.
if os.path.exists(out_file):
    print(f"Target file ({out_file}) already exists! Aborting.")
    sys.exit(1)

# Create the folder for the intermediate frames. An existing folder is treated
# as an error: every *.png found in it would end up in the GIF and would be
# deleted during cleanup.
try:
    os.mkdir(temp_dir)
except FileExistsError:
    print(f"temporary file folder ({temp_dir}) already exists! Aborting.")
    sys.exit(1)


# Running number of the next frame to write; also used as its file name.
image_count = 0
# store datapoints for multiple files in a single list;
# each entry is [x values, y values, legend label]
lines = []
# Right edge of the time axis: the longest max_sim_time read so far, so that
# the progress plot fits the data instead of assuming a fixed duration.
x_limit = 0
# Conversion factor from pixels to inches; it is the same for every frame.
px = 1 / plt.rcParams["figure.dpi"]
# Alignment shared by all labels in the statistics panel.
text_style = {"horizontalalignment": "left", "verticalalignment": "center"}

for file in in_files:
    # NOTE: unpickling can execute arbitrary code. Only load .pkl files that
    # come from a trusted source.
    with open(file, "rb") as f:
        sim_data = pickle.load(f)

    solution_states = sim_data["solution_states"]
    puzzle = sim_data["puzzle"]
    sim_time = sim_data["sim_time"]
    noise_label = f"{sim_data['noise_rate']}Hz"

    x_data = np.arange(0, sim_data["max_sim_time"], sim_time)
    x_limit = max(x_limit, sim_data["max_sim_time"])
    # Score every recorded state up front; the frames only slice this list.
    solution_progress = [get_progress(puzzle, state) for state in solution_states]

    lines.append([[], [], noise_label])

    n_iterations = len(solution_states)

    # One frame per recorded network state.
    for i, current_state in enumerate(solution_states):
        fig, ax = plt.subplots(figsize=(600 * px, 400 * px))
        # Hide the full-figure axes; the panels are placed with subplot2grid.
        ax.set_axis_off()

        # Reveal the curve of the current file up to this frame. Curves of
        # files processed earlier stay in ``lines`` and are drawn in full.
        lines[-1][0] = x_data[: i + 1]
        lines[-1][1] = solution_progress[: i + 1]

        # Lower left panel: performance over simulation time.
        progress = plt.subplot2grid((3, 3), (1, 0), rowspan=2, colspan=1)
        progress.set_ylim(0, 1)
        progress.set_xlim(0, x_limit)
        progress.set_xlabel("simulation time (ms)")
        progress.set_ylabel("performance")
        for x, y, label in lines:
            progress.plot(x, y, label=label)
        progress.legend()

        # Upper left panel: elapsed time and noise rate as plain text.
        stats = plt.subplot2grid((3, 3), (0, 0), 1, 1)
        stats.axis("off")
        stats.text(0, 1, "Time progressed:", fontsize=16, **text_style)
        stats.text(0, 0.7, f"{i * sim_time}ms\n", fontsize=12, color="gray", **text_style)
        stats.text(0, 0.5, "Noise rate:", fontsize=16, **text_style)
        stats.text(0, 0.2, f"{noise_label}\n", fontsize=12, color="gray", **text_style)

        # Right panel: the Sudoku field itself.
        ax = plt.subplot2grid((3, 3), (0, 1), rowspan=3, colspan=2)
        if i == 0:
            # repeat the (colorless) starting configuration several times
            helpers_sudoku.plot_field(puzzle, puzzle, ax, False)
            image_repeat = 8
        else:
            helpers_sudoku.plot_field(puzzle, current_state, ax, True)
            image_repeat = 1

        if i == n_iterations - 1:
            # repeat the final solution a few more times to make it observable
            # before the gif loops again
            image_repeat = 15

        plt.subplots_adjust(wspace=0, hspace=0, left=0.1, right=1.05)
        # A frame is shown for longer by writing it out several times.
        # The zero-padded names keep the files in frame order when sorted.
        for _ in range(image_repeat):
            plt.savefig(os.path.join(temp_dir, f"{image_count:04d}.png"))
            image_count += 1
        # Release the figure so that open figures do not pile up.
        plt.close(fig)

# Collect the rendered frames. Their names are zero-padded frame numbers, so
# sorting the names restores the order in which the frames were written.
filenames = sorted(glob(os.path.join(temp_dir, "*.png")))

images = [imageio.imread(filename) for filename in filenames]

# NOTE(review): how ``duration`` is interpreted (seconds or milliseconds per
# frame) depends on the installed imageio version -- confirm before changing.
imageio.mimsave(out_file, images, duration=250)
print(f"gif created under: {out_file}")

if not keep_temps:
    print("deleting temporary image files...")
    # Use a dedicated loop variable so that the ``in_files`` setting is not
    # overwritten by the cleanup.
    for filename in filenames:
        os.unlink(filename)
    # os.rmdir only removes empty directories, so any file in the folder that
    # is not one of the frames makes this call fail instead of being deleted.
    os.rmdir(temp_dir)
